package edu.northwestern.cbits.purple_robot_manager.models; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import org.json.JSONArray; import org.json.JSONException; import android.content.Context; import android.net.Uri; import edu.northwestern.cbits.purple_robot_manager.R; import edu.northwestern.cbits.purple_robot_manager.logging.LogManager; import edu.northwestern.cbits.purple_robot_manager.models.trees.LeafNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode; import edu.northwestern.cbits.purple_robot_manager.models.trees.TreeNode.TreeNodeException; import edu.northwestern.cbits.purple_robot_manager.models.trees.parsers.TreeNodeParser; import edu.northwestern.cbits.purple_robot_manager.models.trees.parsers.TreeNodeParser.ParserNotFound; public class MatlabForestModel extends WekaTreeModel { public static final String TYPE = "matlab-forest"; private static final String VOTES = "VOTES"; private static final String TREE_COUNT = "TOTAL_VOTERS"; private ArrayList<TreeNode> _trees = new ArrayList<>(); public MatlabForestModel(Context context, Uri uri) { super(context, uri); } protected void generateModel(Context context, Object model) { synchronized (this) { if (model instanceof JSONArray) { JSONArray modelArray = (JSONArray) model; for (int i = 0; i < modelArray.length(); i++) { try { Object modelItem = modelArray.get(i); if (modelItem instanceof String) { try { TreeNode tree = TreeNodeParser.parseString(modelItem.toString()); this._trees.add(tree); } catch (ParserNotFound | TreeNodeException e) { LogManager.getInstance(context).logException(e); } } } catch (JSONException e) { LogManager.getInstance(context).logException(e); } } } } } protected Object evaluateModel(Context context, Map<String, Object> snapshot) { String maxPrediction = null; int maxCount = -1; synchronized (this) { Map<String, Integer> counts = new HashMap<>(); for (TreeNode tree : this._trees) { try { Map<String, Object> prediction = tree.fetchPrediction(snapshot); String treePrediction = prediction.get(LeafNode.PREDICTION).toString(); Integer count = 0; if (counts.containsKey(treePrediction)) count = counts.get(treePrediction); count = count.intValue() + 1; counts.put(treePrediction.toString(), count); } catch (TreeNode.TreeNodeException e) { // e.printStackTrace(); } catch (Exception e) { LogManager.getInstance(context).logException(e); } } for (String prediction : counts.keySet()) { Integer count = counts.get(prediction); if (count > maxCount) { maxCount = count; maxPrediction = prediction; } } } HashMap<String, Object> prediction = new HashMap<>(); prediction.put(LeafNode.PREDICTION, maxPrediction); prediction.put(LeafNode.ACCURACY, (double) maxCount / (double) this._trees.size()); prediction.put(MatlabForestModel.VOTES, maxCount); prediction.put(MatlabForestModel.TREE_COUNT, this._trees.size()); return prediction; } public String summary(Context context) { return context.getString(R.string.summary_model_forest); } public String modelType() { return MatlabForestModel.TYPE; } }